{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b094a85-8595-4d4c-b9ff-9e5c2737c35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install ipympl==0.9.4\n",
    "# %matplotlib widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b80dcf10-a028-47a1-8115-399d7eb14b0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.spatial.transform import Rotation as R\n",
    "from episode_storage import EpisodeReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a57f2a29-7ae5-40c8-8a6d-56f7b4247f27",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_base_pose(base_pose1, base_pose2):\n",
    "    plt.figure()\n",
    "    for j, pose in enumerate([base_pose1, base_pose1]):\n",
    "        x, y, th = pose[:, 0], pose[:, 1], pose[:, 2]\n",
    "        plt.plot(x, y, label=['obs', 'action'][j])\n",
    "        for i in range(0, len(x), 2):  # Adjust the step to reduce or increase arrow density\n",
    "            plt.arrow(x=x[i], y=y[i] + 0.02 * j, dx=0.005 * np.cos(th[i]), dy=0.005 * np.sin(th[i]), head_width=0.0025)\n",
    "    plt.axis('equal')\n",
    "    plt.title('base_pose')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1458682d-c1db-4246-9ce8-c610dd2e9085",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_3d_pose(pos1, quat1, pos2, quat2):\n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    for k, (pos, quat) in enumerate([(pos1, quat1), (pos2, quat2)]):\n",
    "        x, y, z = pos[:,0], pos[:,1], pos[:,2]\n",
    "        ax.plot(x, y, z, label=['obs', 'action'][k])\n",
    "        matrix = R.from_quat(quat).as_matrix()\n",
    "        for i in range(len(pos)):\n",
    "            if i % 5 == 0:\n",
    "                start = pos[i]\n",
    "                start[1] += 0.1 * k\n",
    "                for j in range(3):\n",
    "                    end = start + 0.02 * matrix[i, :, j]\n",
    "                    ax.quiver(*start, *(end - start), color=['r', 'g', 'b'][j], arrow_length_ratio=0.1)\n",
    "        ax.text(x[0], y[0], z[0], 'Start')\n",
    "        ax.text(x[-1], y[-1], z[-1], 'End')\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('y')\n",
    "    ax.set_zlabel('z')\n",
    "    ax.set_aspect('equal')\n",
    "    plt.title('arm_pose')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22a3674b-2ab9-4fac-a3d3-bbda7ddaa280",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_gripper_pos(gripper_pos1, gripper_pos2):\n",
    "    plt.figure()\n",
    "    plt.plot(gripper_pos1, label='obs')\n",
    "    plt.plot(gripper_pos2, label='action')\n",
    "    plt.title('gripper_pos')\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc60f6c-2cf7-4b0c-8eba-c47882edfe62",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_episode(episode_dir):\n",
    "    reader = EpisodeReader(episode_dir)\n",
    "    observations = reader.observations\n",
    "    actions = reader.actions\n",
    "\n",
    "    # base_pose\n",
    "    plot_base_pose(np.array([obs['base_pose'] for obs in observations]), np.array([action['base_pose'] for action in actions]))\n",
    "\n",
    "    # arm_pos and arm_quat\n",
    "    plot_3d_pose(np.array([obs['arm_pos'] for obs in observations]), np.array([obs['arm_quat'] for obs in observations]),\n",
    "                 np.array([action['arm_pos'] for action in actions]), np.array([action['arm_quat'] for action in actions]))\n",
    "\n",
    "    # gripper_pos\n",
    "    plot_gripper_pos(np.array([obs['gripper_pos'] for obs in observations]), np.array([action['gripper_pos'] for action in actions]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93942bb-fe19-4e86-b4ad-07e771bca77d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(input_dir):\n",
    "    episode_dirs = sorted([child for child in Path(input_dir).iterdir() if child.is_dir()])\n",
    "    for episode_dir in episode_dirs:\n",
    "        plot_episode(episode_dir)\n",
    "\n",
    "main('data/sim-v1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5a84f42-aa0e-42b4-8e6b-dff23876554c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tidybot2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
